{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#| eval: false\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from __future__ import annotations\n",
    "from fastai.basics import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp callback.rnn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Callback for RNN training\n",
    "\n",
 Callback that">
    "> Callbacks that use the outputs of language models to add activation regularization (AR) and temporal activation regularization (TAR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@docs\n",
    "class ModelResetter(Callback):\n",
    "    \"`Callback` that resets the model at each validation/training step\"\n",
    "    def before_train(self):    self.model.reset()\n",
    "    def before_validate(self): self.model.reset()\n",
    "    def after_fit(self):       self.model.reset()\n",
    "    _docs = dict(before_train=\"Reset the model before training\",\n",
    "                 before_validate=\"Reset the model before validation\",\n",
    "                 after_fit=\"Reset the model after fitting\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class RNNCallback(Callback):\n",
    "    \"Save the raw and dropped-out outputs and only keep the true output for loss computation\"\n",
    "    def after_pred(self): self.learn.pred,self.raw_out,self.out = [o[-1] if is_listy(o) else o for o in self.pred]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class RNNRegularizer(Callback):\n",
    "    \"Add AR and TAR regularization\"\n",
    "    order,run_valid = RNNCallback.order+1,False\n",
    "    def __init__(self, alpha=0., beta=0.): store_attr()\n",
    "    def after_loss(self):\n",
    "        if not self.training: return\n",
    "        if self.alpha: self.learn.loss_grad += self.alpha * self.rnn.out.float().pow(2).mean()\n",
    "        if self.beta:\n",
    "            h = self.rnn.raw_out\n",
    "            if len(h)>1: self.learn.loss_grad += self.beta * (h[:,1:] - h[:,:-1]).float().pow(2).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def rnn_cbs(alpha=0., beta=0.):\n",
    "    \"All callbacks needed for (optionally regularized) RNN training\"\n",
    "    reg = [RNNRegularizer(alpha=alpha, beta=beta)] if alpha or beta else []\n",
    "    return [ModelResetter(), RNNCallback()] + reg"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev import nbdev_export\n",
    "nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
